#!/usr/bin/env python
#=========================================================================
# proc-xcel-sim [options]
#=========================================================================
#
#  -h --help           Display this message
#
#  --proc-impl         {fl,cl,rtl}
#  --xcel-impl         {fl,cl,rtl,null}
#  --bmark <dataset>   {cksum-xcel, cksum}
#  --translate         Simulate translated and imported DUTs
#  --trace             Display line tracing
#  --limit             Set max number of cycles, default=100000
#
# Author : Shunning Jiang, Christopher Batten
# Date   : June 10, 2019

# Standard library and PyMTL3 imports

import argparse
import os
import struct
import sys

from pymtl3 import *
from pymtl3.stdlib.mem import MagicMemoryCL, mk_mem_msg
from pymtl3.stdlib.test_utils import TestSinkCL, TestSrcCL

# Hack to add project root to python path: walk upwards from the directory
# holding this script until a directory containing pytest.ini is found, then
# put that directory at the front of sys.path (presumably so the `examples`
# package can be imported).
sim_dir = os.path.dirname( os.path.abspath( __file__ ) )
while True:
  if os.path.exists( os.path.join( sim_dir, "pytest.ini" ) ):
    sys.path.insert( 0, sim_dir )
    break
  # os.path.dirname() of a filesystem root returns the root itself, so the
  # path never becomes empty; stop here instead of looping forever when no
  # pytest.ini exists anywhere above this script.
  if os.path.dirname( sim_dir ) == sim_dir:
    break
  sim_dir = os.path.dirname( sim_dir )

from examples.ex03_proc.NullXcel import NullXcelRTL
from examples.ex03_proc.ProcCL import ProcCL
from examples.ex03_proc.ProcFL import ProcFL
from examples.ex03_proc.ProcRTL import ProcRTL
from examples.ex03_proc.SparseMemoryImage import SparseMemoryImage
from examples.ex03_proc.test.harness import TestHarness
from examples.ex03_proc.tinyrv0_encoding import assemble
from examples.ex03_proc.ubmark.proc_ubmark_cksum_roll import ubmark_cksum_roll
from examples.ex04_xcel.ChecksumXcelCL import ChecksumXcelCL
from examples.ex04_xcel.ChecksumXcelFL import ChecksumXcelFL
from examples.ex04_xcel.ChecksumXcelRTL import ChecksumXcelRTL
from examples.ex04_xcel.ProcXcel import ProcXcel
from examples.ex04_xcel.ubmark.proc_ubmark_cksum_xcel_roll import ubmark_cksum_xcel_roll

#=========================================================================
# Command line processing
#=========================================================================

class ArgumentParserWithCustomError(argparse.ArgumentParser):
  """ArgumentParser that reports errors by printing this script's header.

  The usage text is not generated by argparse; it is the leading block of
  '#' comment lines of the file named by sys.argv[0].
  """

  def error( self, msg = "" ):
    """Print an optional error message and the header comment, then exit.

    Args:
      msg: Error text. When empty, only the header is printed (used for
        --help).

    Raises:
      SystemExit: always. The exit status is `msg != ""`, i.e. 0 when no
        message was given and 1 otherwise.
    """
    if msg: print("\n"+f" ERROR: {msg}")
    print("")

    # The context manager closes the script file before exiting.
    with open( sys.argv[0] ) as script:
      for ( lineno, line ) in enumerate( script ):
        # The header ends at the first line that is not a comment.
        if line[0] != '#': break
        # Print the title line (index 2) and everything from index 4 on;
        # this skips the shebang and the two banner rules around the title.
        if ( lineno == 2 ) or ( lineno >= 4 ):
          print( line[1:].rstrip("\n") )

    # Exit even if the file consisted only of comment lines: argparse
    # expects error() never to return.
    sys.exit( msg != "" )

def parse_cmdline():
  """Parse the simulator's command line options from sys.argv.

  Returns:
    The argparse namespace with attributes help, trace, proc_impl,
    xcel_impl, translate, bmark and limit.

  When --help is given, control is handed to the parser's error() hook
  with no message.
  """
  parser = ArgumentParserWithCustomError( add_help=False )

  # Each entry is (option strings, add_argument keyword arguments).
  # Options are registered in the order listed.
  option_table = (

    # Standard command line arguments

    ( ( "-h", "--help" ), dict( action="store_true" ) ),

    # Additional command line arguments for the simulator

    ( ( "--trace",     ), dict( action="store_true" ) ),
    ( ( "--proc-impl", ), dict( default="rtl",
                                choices=["fl", "cl", "rtl"] ) ),
    ( ( "--xcel-impl", ), dict( default="rtl",
                                choices=["fl", "cl", "rtl", "null"] ) ),
    ( ( "--translate", ), dict( action="store_true" ) ),
    ( ( "--bmark",     ), dict( default="cksum-xcel",
                                choices=["cksum", "cksum-xcel"] ) ),
    ( ( "--limit",     ), dict( default=100000, type=int ) ),
  )

  for flags, settings in option_table:
    parser.add_argument( *flags, **settings )

  parsed = parser.parse_args()
  if parsed.help:
    parser.error()
  return parsed

# Lookup tables that map command line choice strings to the objects they
# select. The keys match the `choices` lists given to the argument parser.

# --proc-impl value -> processor class
proc_impl_dict = {
  "fl" : ProcFL,
  "cl" : ProcCL,
  "rtl": ProcRTL,
}
# --xcel-impl value -> accelerator class
xcel_impl_dict = {
  "fl"   : ChecksumXcelFL,
  "cl"   : ChecksumXcelCL,
  "rtl"  : ChecksumXcelRTL,
  "null" : NullXcelRTL,
}

# --bmark value -> benchmark object; main() calls gen_mem_image() and
# verify() on the selected entry
bmark_dict = {
  "cksum-xcel" : ubmark_cksum_xcel_roll,
  "cksum"      : ubmark_cksum_roll
}

class TestHarness(Component):
  """Harness that connects a ProcXcel DUT to a test source, a test sink
  and a magic memory.

  Note that this class rebinds the name TestHarness, which the file also
  imports from examples.ex03_proc.test.harness.
  """

  #-----------------------------------------------------------------------
  # constructor
  #-----------------------------------------------------------------------

  def construct( s, ProcClass, XcelClass, dump_vcd,
                 src_delay, sink_delay,
                 mem_stall_prob, mem_latency ):
    """Instantiate and connect the harness components.

    Args:
      ProcClass: processor class handed to ProcXcel.
      XcelClass: accelerator class handed to ProcXcel.
      dump_vcd: accepted but not used in this method.
      src_delay: passed as both delay arguments of the test source.
      sink_delay: passed as both delay arguments of the test sink.
      mem_stall_prob: accepted but not used in this method.
      mem_latency: passed as the `latency` of the magic memory.
    """

    # Exposed so the simulation loop can count committed instructions
    s.commit_inst = OutPort( Bits1 )

    # Both endpoints start with empty message lists; load() fills them
    s.src  = TestSrcCL ( Bits32, [], src_delay,  src_delay  )
    s.sink = TestSinkCL( Bits32, [], sink_delay, sink_delay )

    # Two memory interfaces: ifc[0] serves imem, ifc[1] serves dmem
    s.mem  = MagicMemoryCL( 2, latency = mem_latency )

    s.dut = ProcXcel( ProcClass, XcelClass )

    # Manager channels
    s.dut.mngr2proc //= s.src.send
    s.dut.proc2mngr //= s.sink.recv

    # Memory channels
    s.dut.imem //= s.mem.ifc[0]
    s.dut.dmem //= s.mem.ifc[1]

    s.dut.commit_inst //= s.commit_inst

  #-----------------------------------------------------------------------
  # load
  #-----------------------------------------------------------------------

  def load( s, mem_image ):
    """Distribute the sections of `mem_image` over the harness.

    The data of a `.mngr2proc` section is appended to the source's message
    list and the data of a `.proc2mngr` section to the sink's message
    list, one Bits32 per little-endian 32-bit word. Every other section is
    written into the memory at the section's address.
    """

    for sec in mem_image.get_sections():

      # Pick the endpoint whose message list receives this section

      if sec.name == ".mngr2proc":
        endpoint = s.src
      elif sec.name == ".proc2mngr":
        endpoint = s.sink
      else:

        # For all other sections, simply copy them into the memory

        s.mem.write_mem( sec.addr, sec.data )
        continue

      # Decode the section four bytes at a time

      for offset in range( 0, len(sec.data), 4 ):
        ( word, ) = struct.unpack_from( "<I", sec.data, offset )
        endpoint.msgs.append( Bits32(word) )

  #-----------------------------------------------------------------------
  # done
  #-----------------------------------------------------------------------

  def done( s ):
    """Return a truthy value once both the source and the sink are done."""
    src_finished = s.src.done()
    return src_finished and s.sink.done()

  #-----------------------------------------------------------------------
  # line_trace
  #-----------------------------------------------------------------------

  def line_trace( s ):
    """Join the line traces of the source, DUT, memory and sink."""
    src_trace  = s.src.line_trace()
    dut_trace  = s.dut.line_trace()
    mem_trace  = s.mem.line_trace()
    sink_trace = s.sink.line_trace()
    return src_trace + " >" + dut_trace + mem_trace + " > " + sink_trace

#=========================================================================
# Main
#=========================================================================

def main():
  """Run one benchmark on the processor/accelerator model and report stats.

  Steps: parse the command line, reject conflicting options, generate the
  benchmark memory image, build the harness (optionally translated and
  imported through Yosys), simulate until the harness is done or the cycle
  limit is reached, verify memory contents, and print cycle statistics.

  Raises:
    SystemExit: always. Status 0 on success; status 1 (with a message on
      stderr where applicable) for conflicting options, hitting the cycle
      limit, or failed verification.
  """
  opts = parse_cmdline()

  # Check if there are any conflicts in the given options. Explicit checks
  # are used instead of assert so they still run under `python -O`.

  # --translate can only be used on RTL proc and RTL/Null xcel
  if opts.translate:
    if opts.proc_impl != "rtl":
      sys.exit( "--translate option can only be used with RTL processor implementation!" )
    if opts.xcel_impl not in ( "rtl", "null" ):
      sys.exit( "--translate option can only be used with NullXcel or RTL accelerator!" )

  # If --xcel-impl null is given, then only cksum is valid as bmark
  if opts.xcel_impl == "null" and opts.bmark != "cksum":
    sys.exit( "--xcel-impl null option can only be used with cksum bmark!" )

  # Assemble the test program

  bmark = bmark_dict[ opts.bmark ]
  mem_image = bmark.gen_mem_image()

  #-----------------------------------------------------------------------
  # Setup simulator
  #-----------------------------------------------------------------------

  # Create test harness and elaborate

  print()
  print("----- Proc:", opts.proc_impl.upper(), "-"*50 , "Xcel:", opts.xcel_impl.upper(), "-----")
  print()

  model = TestHarness( proc_impl_dict[ opts.proc_impl ],
                       xcel_impl_dict[ opts.xcel_impl ], 0,
                       # src  sink  memstall  memlat
                         0,   0,    0,        1 )

  # Apply translation pass and import pass if required

  if opts.translate:
    # Imported lazily so the Yosys backend is only needed with --translate
    from pymtl3.passes.backends.yosys import YosysTranslationImportPass
    model.elaborate()
    model.dut.set_metadata( YosysTranslationImportPass.enable, True )
    model = YosysTranslationImportPass()( model )

  model.apply( DefaultPassGroup(print_line_trace=opts.trace) )

  # Load the program into the model

  model.load( mem_image )

  #-----------------------------------------------------------------------
  # Run the simulation
  #-----------------------------------------------------------------------

  commit_inst = 0

  model.sim_reset()

  limit = opts.limit

  while not model.done() and model.sim_cycle_count() < limit:
    model.sim_tick()
    commit_inst += int(model.commit_inst)

  num_cycles = model.sim_cycle_count()

  if num_cycles >= limit:
    sys.exit( f"Simulation reached the cycle limit ({limit}) before finishing!" )

  # Verify the results of simulation

  print()
  passed = bmark.verify( model.mem.mem.mem )
  print()

  if not passed:
    sys.exit(1)

  # Display stats

  # Avoid ZeroDivisionError if no instruction was ever committed
  cpi = num_cycles / commit_inst if commit_inst else float("nan")

  print( "  total_num_cycles      = {}".format( num_cycles ) )
  print( "  total_committed_insts = {}".format( commit_inst ) )
  print( "  CPI                   = {:1.2f}".format( cpi ) )
  print()

  sys.exit(0)

# Run the simulator only when this file is executed as a script, so that
# importing it does not start a simulation.
if __name__ == "__main__":
  main()
